#!/usr/bin/env python3
"""
Followup pipeline: Expanded Country Coverage

Builds on the original Demographics & Capital Flows pipeline (Peters 2026)
with an expanded sample (~108 countries vs. original 69). All outputs are
written to followup/ directories to keep the original paper results intact.

Usage:
    python followup/run_pipeline.py [--step STEP[,STEP...]] [--force] [--fred-key KEY] [--original]
    python followup/run_pipeline.py --step estimate  # re-estimate on expanded sample
    python followup/run_pipeline.py --original        # run on original 69 countries for comparison

Steps (in order):
    download    - Download all data sources (shared with original)
    demographics - Process UN WPP and construct polynomial variables
    macro       - Assemble macro control variable panel (with IFS fixes)
    rates       - Assemble interest rate panel
    merge       - Merge all panels into unified dataset
    estimate    - Estimate models on expanded sample
    scenarios   - Run projections and counterfactuals
    visualize   - Generate figures and tables
    all         - Run all steps (default)
"""

import sys
import os
import argparse
import pandas as pd
import numpy as np
from pathlib import Path

# Make the followup code importable. FOLLOWUP_DIR is inserted at sys.path[0], so
# the `src.*` imports below resolve against followup/src before the original
# project's src. PROJECT_DIR is inserted directly after it.
# NOTE(review): if both directories hold a *regular* `src` package (with
# __init__.py), only the followup one is used. A module missing there does not
# fall back to PROJECT_DIR/src — confirm every imported module exists in followup/src.
FOLLOWUP_DIR = Path(__file__).parent
PROJECT_DIR = FOLLOWUP_DIR.parent
sys.path.insert(0, str(FOLLOWUP_DIR))
sys.path.insert(1, str(PROJECT_DIR))

from src.download import download_all
from src.demographics import (
    process_un_wpp, construct_polynomial_variables,
    compute_dependency_ratios, compute_future_oadr
)
from src.macro import assemble_macro_panel, filter_eba_sample
from src.interest_rates import (
    assemble_rate_panel, compute_real_rate_differentials,
    compute_demographic_distance, construct_interaction_terms
)
from src.model import (
    estimate_baseline_model, estimate_extended_model,
    estimate_demographics_only, compare_models,
    country_residual_analysis, extract_demographic_contribution,
    estimate_nonlinearity_tests, estimate_rate_channel_tests,
    project_rate_channel
)
from src.scenarios import (
    project_demographic_contribution, compute_country_profiles,
    china_counterfactual, decompose_residuals, generate_projection_table,
    compute_openness_marginal_effects, compute_openness_scenarios,
    compute_global_efficiency, compute_ge_clearing
)
from src.visualize import (
    plot_age_coefficients, plot_country_demographic_contributions,
    plot_projections, plot_model_comparison, plot_residual_map,
    plot_china_counterfactual, save_regression_table, save_latex_table
)

# Followup locations: every artifact this pipeline writes goes under
# FOLLOWUP_DIR, leaving the original paper's result paths untouched.
PROCESSED_DIR = FOLLOWUP_DIR / "data" / "processed"
OUTPUT_DIR = FOLLOWUP_DIR / "output"

# Original-project locations, kept for reading shared raw/processed data
# when needed.
ORIG_PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
ORIG_OUTPUT_DIR = PROJECT_DIR / "output"

# Create the output tree up front. This runs at import time, so importing
# this module has a filesystem side effect.
for _required_dir in (OUTPUT_DIR / "tables", OUTPUT_DIR / "figures", PROCESSED_DIR):
    _required_dir.mkdir(parents=True, exist_ok=True)
del _required_dir

# Sample switch read by step_estimate: True selects the expanded country set.
# main() sets it to False when --original is passed.
USE_EXPANSION = True


def step_download(fred_key=None, force=False):
    """Step 1: Fetch every raw data source.

    Thin wrapper around ``download_all``.

    Args:
        fred_key: FRED API key, forwarded as ``fred_api_key``.
        force: Forwarded unchanged (the CLI describes it as forcing a
            re-download of data).

    Returns:
        Whatever ``download_all`` returns.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("STEP 1: DATA DOWNLOAD")
    print(banner)
    return download_all(fred_api_key=fred_key, force=force)


def step_demographics(force=False):
    """Step 2: Process UN WPP demographics and build the polynomial regressors.

    Reads age-group shares via ``process_un_wpp``, attaches optional GDP
    weights from a WEO extract, and builds the cubic polynomial variables.
    It also writes ``dependency_ratios.csv`` and ``future_oadr.csv`` (20-year
    horizon) to ``PROCESSED_DIR``.

    Args:
        force: Accepted for symmetry with the other steps; currently unused.

    Returns:
        ``(shares, polys)`` on success. Returns ``(None, None)`` if the WPP data
        could not be processed.
    """
    print("\n" + "=" * 70)
    print("STEP 2: DEMOGRAPHIC PROCESSING")
    print("=" * 70)

    # Process raw WPP data into age-group shares
    shares = process_un_wpp()
    if shares is None:
        print("ERROR: Could not process UN WPP data")
        return None, None

    print(f"\nDemographic shares: {shares.shape}")
    print(f"Countries: {shares['iso3'].nunique()}")
    print(f"Years: {shares['year'].min()}-{shares['year'].max()}")

    # GDP data for weighting (optional). The legacy absolute path is tried
    # first, so existing setups behave exactly as before. The project-relative
    # location keeps GDP weighting working on other machines instead of
    # silently dropping it.
    weo_candidates = (
        Path("/mnt/c/demographics_capital_flows/multilateral/data/raw/weo_data.csv"),
        PROJECT_DIR / "data" / "raw" / "weo_data.csv",
    )
    weo_path = next((p for p in weo_candidates if p.exists()), None)

    gdp_df = None
    if weo_path is not None:
        weo = pd.read_csv(weo_path)
        if 'ngdp_usd' in weo.columns:
            gdp_df = (
                weo[['iso3', 'year', 'ngdp_usd']]
                .rename(columns={'ngdp_usd': 'gdp'})
                .dropna()
            )
    if gdp_df is None:
        print("  Note: no WEO GDP data found; constructing polynomials with gdp_df=None")

    # Construct polynomial variables
    polys = construct_polynomial_variables(shares, gdp_df=gdp_df, P=3)

    # Compute dependency ratios
    dep_ratios = compute_dependency_ratios(shares)
    dep_ratios.to_csv(PROCESSED_DIR / "dependency_ratios.csv", index=False)

    # Compute future OADR (old-age dependency ratio 20 years ahead)
    future_oadr = compute_future_oadr(shares, horizon=20)
    future_oadr.to_csv(PROCESSED_DIR / "future_oadr.csv", index=False)

    return shares, polys


def step_macro():
    """Step 3: Build the macro control-variable panel.

    Returns:
        The result of ``assemble_macro_panel()``, passed through unchanged.
    """
    rule = "=" * 70
    for line in ("\n" + rule, "STEP 3: MACRO PANEL ASSEMBLY", rule):
        print(line)
    return assemble_macro_panel()


def step_rates():
    """Step 4: Build the interest-rate panel.

    Returns:
        The result of ``assemble_rate_panel()``, passed through unchanged.
    """
    separator = "=" * 70
    print(f"\n{separator}")
    print("STEP 4: INTEREST RATE PANEL")
    print(separator)
    rate_panel = assemble_rate_panel()
    return rate_panel


def step_merge(shares_df=None, polys_df=None, macro_df=None, rates_df=None):
    """Step 5: Merge all panels into unified dataset.

    The demographic polynomials are the base panel. The remaining stages run
    in this order:

    1. Left-join the macro controls.
    2. Left-join the interest rates, with real differentials when the macro
       panel carries inflation.
    3. Add derived variables: log lending rate, winsorized fiscal balance,
       and the NFA / life-expectancy nonlinearity terms.
    4. Left-join the dependency ratios and future OADR from PROCESSED_DIR.
    5. Add the life-expectancy x future-OADR interaction, the openness
       interactions and the demographic distance.

    All joins are on ``(iso3, year)``. The result is written to
    ``PROCESSED_DIR / "full_panel.csv"``.

    Args:
        shares_df: Demographic age-group shares, used only for the
            demographic-distance merge. Loaded from ``demographic_shares.csv``
            when ``None``.
        polys_df: Demographic polynomial panel (required). Loaded from
            ``demographic_polynomials.csv`` when ``None``.
        macro_df: Macro control panel. Loaded from ``macro_panel.csv`` when
            ``None``.
        rates_df: Interest-rate panel. Loaded from ``interest_rate_panel.csv``
            when ``None``.

    Returns:
        The merged panel DataFrame, or ``None`` if no demographic polynomials
        are available. Any other input that is missing is simply skipped.
    """
    print("\n" + "=" * 70)
    print("STEP 5: PANEL MERGE")
    print("=" * 70)

    # Load from files if not passed directly
    if polys_df is None:
        poly_path = PROCESSED_DIR / "demographic_polynomials.csv"
        if poly_path.exists():
            polys_df = pd.read_csv(poly_path)

    if macro_df is None:
        macro_path = PROCESSED_DIR / "macro_panel.csv"
        if macro_path.exists():
            macro_df = pd.read_csv(macro_path)

    if rates_df is None:
        rate_path = PROCESSED_DIR / "interest_rate_panel.csv"
        if rate_path.exists():
            rates_df = pd.read_csv(rate_path)

    if shares_df is None:
        share_path = PROCESSED_DIR / "demographic_shares.csv"
        if share_path.exists():
            shares_df = pd.read_csv(share_path)

    # Start with demographics as base
    if polys_df is not None:
        panel = polys_df.copy()
        print(f"  Demographics base: {panel.shape}")
    else:
        print("ERROR: No demographic data available")
        return None

    # Merge macro.
    # NOTE(review): a left merge repeats base rows if the right-hand frame has
    # duplicate (iso3, year) keys. Overlapping non-key columns also get _x/_y
    # suffixes. Confirm each input is unique per country-year.
    if macro_df is not None:
        panel = panel.merge(macro_df, on=['iso3', 'year'], how='left')
        print(f"  After macro merge: {panel.shape}")

    # Merge rates
    if rates_df is not None:
        # Compute real rate differentials if we have inflation
        if macro_df is not None and 'inflation' in macro_df.columns:
            rates_df = compute_real_rate_differentials(rates_df, macro_df)
        panel = panel.merge(rates_df, on=['iso3', 'year'], how='left')
        print(f"  After rates merge: {panel.shape}")

    # --- Variable transformations ---

    # Log-transform lending rate: raw lending rates span 0-99,765% due to
    # hyperinflation episodes, making the linear specification meaningless.
    # log(1 + rate/100) maps to continuously compounded rates on a sensible scale.
    if 'lending_rate' in panel.columns:
        panel['log_lending_rate'] = np.log1p(panel['lending_rate'] / 100)
        n_extreme = (panel['lending_rate'] > 100).sum()
        print(f"  Created log_lending_rate ({n_extreme} obs with raw rate > 100%)")

    # Winsorize fiscal_bal_gdp at 1st/99th percentiles to limit influence of
    # extreme outliers (raw range: -557 to +125).
    # The percentiles are taken over every non-missing row of the merged panel
    # (all countries and years), before step_estimate restricts the sample.
    if 'fiscal_bal_gdp' in panel.columns:
        fb = panel['fiscal_bal_gdp'].dropna()
        p01, p99 = fb.quantile(0.01), fb.quantile(0.99)
        n_clipped = ((panel['fiscal_bal_gdp'] < p01) | (panel['fiscal_bal_gdp'] > p99)).sum()
        panel['fiscal_bal_gdp'] = panel['fiscal_bal_gdp'].clip(lower=p01, upper=p99)
        print(f"  Winsorized fiscal_bal_gdp to [{p01:.1f}, {p99:.1f}] ({n_clipped} obs clipped)")

    # Nonlinearity variables for specification testing: a quadratic term plus
    # a piecewise split of lagged NFA into its positive and negative parts.
    if 'nfa_gdp_lag' in panel.columns:
        panel['nfa_gdp_lag_sq'] = panel['nfa_gdp_lag'] ** 2
        panel['nfa_positive'] = panel['nfa_gdp_lag'].clip(lower=0)
        panel['nfa_negative'] = panel['nfa_gdp_lag'].clip(upper=0)
        print(f"  Created NFA nonlinearity variables (quadratic, creditor/debtor split)")

    if 'life_expectancy' in panel.columns:
        panel['life_expectancy_sq'] = panel['life_expectancy'] ** 2
        print(f"  Created life_expectancy_sq for convexity test")

    # Merge dependency ratios (written by step_demographics; skipped if absent)
    dep_path = PROCESSED_DIR / "dependency_ratios.csv"
    if dep_path.exists():
        dep = pd.read_csv(dep_path)
        panel = panel.merge(dep, on=['iso3', 'year'], how='left')

    # Merge future OADR (written by step_demographics; skipped if absent)
    oadr_path = PROCESSED_DIR / "future_oadr.csv"
    if oadr_path.exists():
        oadr = pd.read_csv(oadr_path)
        panel = panel.merge(oadr, on=['iso3', 'year'], how='left')

    # Life expectancy × future OADR interaction
    if 'life_expectancy' in panel.columns and 'oadr_plus20' in panel.columns:
        panel['le_x_future_oadr'] = panel['life_expectancy'] * panel['oadr_plus20']

    # Construct financial openness interactions
    panel = construct_interaction_terms(panel)

    # Compute demographic distance. This is best-effort: any failure is
    # logged and the panel is saved without these columns.
    if shares_df is not None:
        try:
            demo_dist = compute_demographic_distance(shares_df)
            if demo_dist is not None and len(demo_dist) > 0:
                panel = panel.merge(demo_dist, on=['iso3', 'year'], how='left')
                print(f"  After demo distance merge: {panel.shape}")
        except Exception as e:
            print(f"  Warning: Could not compute demographic distance: {e}")

    # Save full panel
    panel.to_csv(PROCESSED_DIR / "full_panel.csv", index=False)
    print(f"\n  Full panel saved: {panel.shape}")
    print(f"  Countries: {panel['iso3'].nunique()}")
    print(f"  Years: {panel['year'].min()}-{panel['year'].max()}")

    # Summary of variable coverage (variables absent from the panel are skipped)
    key_vars = ['ca_gdp', 'Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'nfa_gdp_lag',
                'log_rel_opw', 'health_exp_gdp', 'kaopen', 'real_bond_10y_diff']
    print("\n  Key variable coverage:")
    for v in key_vars:
        if v in panel.columns:
            n = panel[v].notna().sum()
            nc = panel.loc[panel[v].notna(), 'iso3'].nunique()
            print(f"    {v}: {n:,} obs, {nc} countries")

    return panel


def step_estimate(panel_df=None):
    """Step 6: Estimate the model suite on the EBA estimation sample.

    The sample is restricted to rows with a non-missing ``ca_gdp`` in
    1986-2024. It is then passed through ``filter_eba_sample``, which receives
    the global ``USE_EXPANSION`` switch.

    Each model and test is estimated independently. A failure is logged and
    skipped, so the remaining models still run.

    Args:
        panel_df: Merged panel from ``step_merge``. When ``None`` it is read
            from ``PROCESSED_DIR / "full_panel.csv"``.

    Returns:
        A dict keyed by model name. The three core models map to the
        ``(model, estimation_df)`` pairs returned by their estimators. Entries
        from the nonlinearity and rate-channel tests are merged in as those
        helpers return them, except nested dicts from the rate tests.
        Returns ``None`` when no panel is available or it lacks ``ca_gdp``.
    """
    print("\n" + "=" * 70)
    print("STEP 6: MODEL ESTIMATION")
    print("=" * 70)

    if panel_df is None:
        panel_path = PROCESSED_DIR / "full_panel.csv"
        if panel_path.exists():
            panel_df = pd.read_csv(panel_path)
        else:
            print("ERROR: No panel data found")
            return None

    # Without the dependent variable the sample filter below would fail with a
    # bare KeyError. Report it the same way as a missing panel instead.
    if 'ca_gdp' not in panel_df.columns:
        print("ERROR: Panel has no 'ca_gdp' column; cannot build estimation sample")
        return None

    # Filter to estimation sample: historical data with CA/GDP
    est_sample = panel_df[
        (panel_df['ca_gdp'].notna()) &
        (panel_df['year'] >= 1986) &
        (panel_df['year'] <= 2024)
    ].copy()
    print(f"Estimation sample: {est_sample['iso3'].nunique()} countries, {len(est_sample):,} obs, "
          f"{est_sample['year'].min()}-{est_sample['year'].max()}")

    # EBA sample — expanded includes EU completion + Tier 1 countries
    eba_sample = filter_eba_sample(est_sample, extended=True, expansion=USE_EXPANSION)

    results = {}

    # Model 1: Demographics only
    try:
        m_demo, df_demo = estimate_demographics_only(eba_sample)
        results['Demographics Only'] = (m_demo, df_demo)
    except Exception as e:
        print(f"  Demographics-only model failed: {e}")

    # Model 2: Baseline (demographics + EBA controls)
    try:
        m_base, df_base = estimate_baseline_model(eba_sample)
        results['Baseline (Demo + EBA)'] = (m_base, df_base)
    except Exception as e:
        print(f"  Baseline model failed: {e}")

    # Model 3: Extended (+ interest rates + interactions)
    try:
        m_ext, df_ext = estimate_extended_model(eba_sample)
        results['Extended (+ Rates)'] = (m_ext, df_ext)
    except Exception as e:
        print(f"  Extended model failed: {e}")

    # Nonlinearity tests (NFA squared, creditor/debtor split, life expectancy squared)
    try:
        nl_results, nl_summary = estimate_nonlinearity_tests(eba_sample)
        if len(nl_summary) > 0:
            nl_summary.to_csv(OUTPUT_DIR / "tables" / "nonlinearity_tests.csv", index=False)
        results.update(nl_results)
    except Exception as e:
        print(f"  Nonlinearity tests failed: {e}")

    # Rate channel tests (bond yields, term spread, two-stage Carvalho)
    try:
        rate_results, rate_summary = estimate_rate_channel_tests(eba_sample)
        if len(rate_summary) > 0:
            rate_summary.to_csv(OUTPUT_DIR / "tables" / "rate_channel_tests.csv", index=False)
        # Skip nested dicts (the two-stage output); they are not (model, df) pairs.
        results.update({k: v for k, v in rate_results.items()
                        if not isinstance(v, dict)})
    except Exception as e:
        print(f"  Rate channel tests failed: {e}")

    # Rate channel projections (demographic pressure on bond yields).
    # The return value is not used here.
    # NOTE(review): presumably project_rate_channel persists its own output — confirm.
    try:
        polys_path = PROCESSED_DIR / "demographic_polynomials.csv"
        if polys_path.exists():
            polys_for_proj = pd.read_csv(polys_path)
            project_rate_channel(eba_sample, polys_for_proj)
    except Exception as e:
        print(f"  Rate channel projections failed: {e}")

    # Compare models. Guarded like every other estimation call, so a
    # comparison failure does not discard the already-estimated results.
    if results:
        try:
            comparison = compare_models(results)
            comparison.to_csv(OUTPUT_DIR / "tables" / "model_comparison.csv", index=False)
        except Exception as e:
            print(f"  Model comparison failed: {e}")

    return results


def step_scenarios(results=None, shares_df=None, polys_df=None):
    """Step 7: Run projections, counterfactuals and openness/GE scenarios.

    The model-dependent outputs (country profiles, demographic projections,
    residual decompositions, projection table) need ``results`` from
    ``step_estimate``. The openness scenarios read coefficients from a CSV in
    ``OUTPUT_DIR / "tables"``. GE capital-market clearing runs whenever both a
    panel and the demographic polynomials are available.

    Args:
        results: Model dict from ``step_estimate`` (name -> ``(model, df)``).
            ``None`` or empty skips the model-dependent projections.
        shares_df: Demographic shares, used for the China counterfactual.
        polys_df: Demographic polynomials. Loaded from ``PROCESSED_DIR`` when
            ``None``.

    Returns:
        A dict of scenario outputs; anything that could not be computed is
        ``None``. The ``china_counterfactual``, ``efficiency_detail``,
        ``ge_projections`` and ``ge_clearing`` keys expose results that were
        previously computed and discarded.
    """
    print("\n" + "=" * 70)
    print("STEP 7: SCENARIOS & PROJECTIONS")
    print("=" * 70)

    # Load polys and panel even if no model results, for openness scenarios
    if polys_df is None:
        poly_path = PROCESSED_DIR / "demographic_polynomials.csv"
        if poly_path.exists():
            polys_df = pd.read_csv(poly_path)

    panel_path = PROCESSED_DIR / "full_panel.csv"
    panel = pd.read_csv(panel_path) if panel_path.exists() else None

    china_cf = None
    if results is None or not results:
        print("  No model results available — running openness scenarios only")
        # Skip model-dependent projections, jump to openness scenarios
        profiles, alpha, projections, decomp, resid_analysis = None, None, None, None, None
    else:
        # Use baseline model for projections (fall back to the first model)
        model_name = 'Baseline (Demo + EBA)' if 'Baseline (Demo + EBA)' in results else list(results.keys())[0]
        model, df_est = results[model_name]

        # Get feature names from the model object
        if model.feature_names:
            feature_names = model.feature_names
        else:
            # Reconstruct the regressor list: demographic polynomials plus any
            # control with more than 100 non-missing observations.
            demo_vars = ['Z_1', 'Z_2', 'Z_3']
            potential_controls = [
                'fiscal_bal_gdp', 'kaopen', 'expected_growth',
                'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp', 'life_expectancy',
            ]
            controls = [c for c in potential_controls if c in df_est.columns and df_est[c].notna().sum() > 100]
            feature_names = demo_vars + controls
        print(f"  Feature names ({len(feature_names)}): {feature_names}")
        print(f"  Model has {len(model.beta)} coefficients")

        if panel is None:
            panel = df_est

        # Compute country profiles
        focus_countries = ['CHN', 'IND', 'IDN', 'JPN', 'USA', 'DEU', 'BRA', 'NGA', 'ZAF',
                           'KOR', 'GBR', 'AUS', 'MEX', 'TUR', 'SAU']
        profiles, alpha = compute_country_profiles(model, panel, feature_names,
                                                    countries=focus_countries)
        print(f"\n  Computed profiles for {len(profiles)} countries")
        print(f"  Implied age-group coefficients: {alpha}")

        # Project demographic contributions
        projections = project_demographic_contribution(
            model, polys_df, feature_names,
            start_year=2025, end_year=2060
        )
        print(f"  Projections: {projections.shape}")

        # China counterfactual
        if shares_df is not None:
            china_cf = china_counterfactual(shares_df)
            if china_cf is not None:
                print("  Computed China counterfactual age structure")

        # Residual decomposition
        decomp = decompose_residuals(model, df_est, feature_names)
        decomp.to_csv(OUTPUT_DIR / "tables" / "residual_decomposition.csv", index=False)

        # Country residual analysis
        resid_analysis = country_residual_analysis(model, df_est)
        resid_analysis.to_csv(OUTPUT_DIR / "tables" / "country_residuals.csv", index=False)

        # Projection table
        proj_table = generate_projection_table(profiles)
        proj_table.to_csv(OUTPUT_DIR / "tables" / "projection_table.csv", index=False)
        print(f"\n  Projection table:\n{proj_table.to_string(index=False, float_format='%.3f')}")

    # --- Openness scenarios (Item 4.9) ---
    # These use interaction coefficients from the extended model CSV,
    # independent of whether the model was re-estimated in this run.
    # NOTE(review): the file name matches the pattern step_visualize uses for
    # save_regression_table, and step_visualize runs *after* this step. In a
    # fresh `all` run the CSV may therefore be missing; in later runs it may
    # hold the previous run's coefficients — confirm where it is produced.
    coeff_path = OUTPUT_DIR / "tables" / "regression_extended_plus_interactions.csv"
    if coeff_path.exists():
        print("\n  --- Openness Scenarios ---")
        try:
            marginal_effects = compute_openness_marginal_effects(panel, polys_df, coeff_path)
        except Exception as e:
            print(f"  Marginal effects failed: {e}")
            marginal_effects = None

        try:
            opening_df, closing_df = compute_openness_scenarios(panel, polys_df, coeff_path)
        except Exception as e:
            print(f"  Opening/closing scenarios failed: {e}")
            opening_df, closing_df = None, None

        try:
            efficiency_detail, efficiency_summary = compute_global_efficiency(panel, polys_df, coeff_path)
        except Exception as e:
            print(f"  Global efficiency failed: {e}")
            efficiency_detail, efficiency_summary = None, None
    else:
        print("  Skipping openness scenarios: no extended model coefficients found")
        marginal_effects, opening_df, closing_df = None, None, None
        efficiency_detail, efficiency_summary = None, None

    # --- GE Capital Market Clearing (Item 4.7A) ---
    ge_proj, ge_clearing = None, None
    if panel is not None and polys_df is not None:
        try:
            baseline_model = results.get('Baseline (Demo + EBA)', (None, None))[0] if results else None
            baseline_features = baseline_model.feature_names if baseline_model and baseline_model.feature_names else None
            ge_proj, ge_clearing = compute_ge_clearing(
                panel, polys_df,
                baseline_model=baseline_model,
                feature_names=baseline_features
            )
        except Exception as e:
            print(f"  GE clearing failed: {e}")

    return {
        'profiles': profiles,
        'alpha': alpha,
        'projections': projections,
        'decomposition': decomp,
        'country_residuals': resid_analysis,
        'marginal_effects': marginal_effects,
        'opening_scenarios': opening_df,
        'closing_scenarios': closing_df,
        'global_efficiency': efficiency_summary,
        # Previously computed but dropped; exposed for downstream use.
        'efficiency_detail': efficiency_detail,
        'china_counterfactual': china_cf,
        'ge_projections': ge_proj,
        'ge_clearing': ge_clearing,
    }


def step_visualize(results=None, scenario_results=None):
    """Step 8: Generate figures and regression tables.

    Args:
        results: Model dict from ``step_estimate`` (name -> ``(model, df)``).
            It drives the model-comparison plot and one regression CSV per
            model; ``None`` or empty skips both.
        scenario_results: Dict from ``step_scenarios``. Each figure is drawn
            only when its input entry is present and non-empty.
    """
    print("\n" + "=" * 70)
    print("STEP 8: VISUALIZATION")
    print("=" * 70)

    if scenario_results is not None:
        # Age-group coefficient plot
        alpha = scenario_results.get('alpha')
        if alpha is not None:
            plot_age_coefficients(alpha)

        # Country demographic contributions. Use an explicit None/length test
        # rather than `if profiles:`, which raises ValueError ("truth value of
        # a DataFrame is ambiguous") when profiles is a DataFrame.
        profiles = scenario_results.get('profiles')
        if profiles is not None and len(profiles) > 0:
            plot_country_demographic_contributions(profiles)

        # Projections
        projections = scenario_results.get('projections')
        if projections is not None and len(projections) > 0:
            plot_projections(projections)

        # Country residuals
        country_resid = scenario_results.get('country_residuals')
        if country_resid is not None:
            plot_residual_map(country_resid)

    if results:
        # Model comparison
        comparison = compare_models(results)
        plot_model_comparison(comparison)

        # Save regression tables; file names are derived from the model name,
        # e.g. "Extended (+ Rates)" -> "regression_extended_plus_rates.csv".
        for name, (model, df) in results.items():
            if model is None:
                continue
            feat_names = model.feature_names or [f'X{i+1}' for i in range(len(model.beta))]
            safe_name = name.replace(' ', '_').replace('(', '').replace(')', '').replace('+', 'plus').lower()
            save_regression_table(model, feat_names, f"regression_{safe_name}.csv")

    print("\n  All visualizations saved to output/figures/")
    print("  All tables saved to output/tables/")


def main(argv=None):
    """Parse CLI arguments and run the requested pipeline steps.

    ``--step`` accepts a single step name or a comma-separated list.
    Whitespace around names is ignored, and unknown names are rejected with a
    usage error. Selected steps always run in pipeline order, whatever order
    they were listed in.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            parses ``sys.argv[1:]`` (argparse's default), so the script entry
            point behaves as before.
    """
    global USE_EXPANSION
    valid_steps = ('download', 'demographics', 'macro', 'rates', 'merge',
                   'estimate', 'scenarios', 'visualize', 'all')

    parser = argparse.ArgumentParser(description="Followup: Expanded Country Coverage Pipeline")
    parser.add_argument('--fred-key', type=str, default=None, help='FRED API key')
    parser.add_argument('--step', type=str, default='all',
                        help='Step to run: download, demographics, macro, rates, merge, estimate, scenarios, visualize, all')
    parser.add_argument('--force', action='store_true', help='Force re-download of data')
    parser.add_argument('--original', action='store_true',
                        help='Run on original 69 countries (for comparison)')
    args = parser.parse_args(argv)

    # Normalise "merge, estimate" -> ['merge', 'estimate'] and reject typos.
    # Without this check an unknown name would run nothing yet still report
    # "PIPELINE COMPLETE".
    steps = [s.strip() for s in args.step.split(',') if s.strip()]
    unknown = [s for s in steps if s not in valid_steps]
    if not steps or unknown:
        bad = ', '.join(unknown) if unknown else repr(args.step)
        parser.error(f"invalid --step value(s): {bad}; choose from {', '.join(valid_steps)}")

    if args.original:
        USE_EXPANSION = False
        print("*** Running on ORIGINAL 69-country sample (comparison mode) ***")
    else:
        print("*** Running on EXPANDED sample (~108 countries) ***")
    # Shown in both modes: the comparison run writes to the same OUTPUT_DIR
    # and overwrites the expanded-sample tables there.
    print(f"*** All outputs go to: {OUTPUT_DIR} ***")

    selected = set(steps)
    run_all = 'all' in selected

    shares_df = None
    polys_df = None
    macro_df = None
    rates_df = None
    panel_df = None
    results = None
    scenario_results = None

    if run_all or 'download' in selected:
        step_download(fred_key=args.fred_key, force=args.force)

    if run_all or 'demographics' in selected:
        shares_df, polys_df = step_demographics()

    if run_all or 'macro' in selected:
        macro_df = step_macro()

    if run_all or 'rates' in selected:
        rates_df = step_rates()

    if run_all or 'merge' in selected:
        panel_df = step_merge(shares_df, polys_df, macro_df, rates_df)

    if run_all or 'estimate' in selected:
        results = step_estimate(panel_df)

    if run_all or 'scenarios' in selected:
        scenario_results = step_scenarios(results, shares_df, polys_df)

    if run_all or 'visualize' in selected:
        step_visualize(results, scenario_results)

    print("\n" + "=" * 70)
    print("PIPELINE COMPLETE")
    print("=" * 70)


# Script entry point: run the CLI only when this file is executed directly.
# Importing the module still runs the top-level setup (sys.path edits and
# output-directory creation) but does not start the pipeline.
if __name__ == "__main__":
    main()
